"""
This file contains functions for loading data into MemGPT's archival storage.

Data can be loaded with the following command, once a load function is defined:
```
memgpt load <data-connector-type> --name <dataset-name> [ADDITIONAL ARGS]
```

"""

import uuid
from typing import Annotated, List, Optional

import typer

from memgpt import create_client
from memgpt.data_sources.connectors import DirectoryConnector

# Typer application object; the load commands in this file register on it via `@app.command(...)`.
app = typer.Typer()

# NOTE: the `index` command below is disabled: it is not supported due to llama-index breaking things (please reach out if you still need it)
# @app.command("index")
# def load_index(
#    name: Annotated[str, typer.Option(help="Name of dataset to load.")],
#    dir: Annotated[Optional[str], typer.Option(help="Path to directory containing index.")] = None,
#    user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
# ):
#    """Load a LlamaIndex saved VectorIndex into MemGPT"""
#    if user_id is None:
#        config = MemGPTConfig.load()
#        user_id = uuid.UUID(config.anon_clientid)
#
#    try:
#        # load index data
#        storage_context = StorageContext.from_defaults(persist_dir=dir)
#        loaded_index = load_index_from_storage(storage_context)
#
#        # hacky code to extract out passages/embeddings (thanks a lot, llama index)
#        embed_dict = loaded_index._vector_store._data.embedding_dict
#        node_dict = loaded_index._docstore.docs
#
#        # create storage connector
#        config = MemGPTConfig.load()
#        if user_id is None:
#            user_id = uuid.UUID(config.anon_clientid)
#
#        passages = []
#        for node_id, node in node_dict.items():
#            vector = embed_dict[node_id]
#            node.embedding = vector
#            # assume embedding are the same as config
#            passages.append(
#                Passage(
#                    text=node.text,
#                    embedding=np.array(vector),
#                    embedding_dim=config.default_embedding_config.embedding_dim,
#                    embedding_model=config.default_embedding_config.embedding_model,
#                )
#            )
#            assert config.default_embedding_config.embedding_dim == len(
#                vector
#            ), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(vector)}"
#
#        if len(passages) == 0:
#            raise ValueError(f"No passages found in index {dir}")
#
#        insert_passages_into_source(passages, name, user_id, config)
#    except ValueError as e:
#        typer.secho(f"Failed to load index from provided information.\n{e}", fg=typer.colors.RED)


# Default value of the `extensions` option of the `directory` command: a comma-separated list of file extensions.
default_extensions = ".txt,.md,.pdf"


@app.command("directory")
def load_directory(
    name: Annotated[str, typer.Option(help="Name of dataset to load.")],
    input_dir: Annotated[Optional[str], typer.Option(help="Path to directory containing dataset.")] = None,
    input_files: Annotated[List[str], typer.Option(help="List of paths to files containing dataset.")] = [],
    recursive: Annotated[bool, typer.Option(help="Recursively search for files in directory.")] = False,
    extensions: Annotated[str, typer.Option(help="Comma separated list of file extensions to load")] = default_extensions,
    user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,  # TODO: remove
    description: Annotated[Optional[str], typer.Option(help="Description of the source.")] = None,
):
    """Load a directory and/or a list of files into a new MemGPT data source.

    If loading fails, the newly created source is deleted and the command exits with status 1.
    """
    # NOTE: `user_id` and `description` are accepted as options but are not used by this function.
    client = create_client()

    # Describe what should be ingested (explicit files and/or a directory, filtered by extension).
    connector = DirectoryConnector(
        input_files=input_files,
        input_directory=input_dir,
        recursive=recursive,
        extensions=extensions,
    )

    # Create the source before loading; `load_data` below refers to it by name.
    source = client.create_source(name=name)

    try:
        client.load_data(connector, source_name=name)
    except Exception as e:
        typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
        # Remove the source created above so a failed load does not leave it behind.
        client.delete_source(source.id)
        # Report the failure through the exit status; previously the command
        # returned normally (status 0) even though nothing was loaded.
        raise typer.Exit(code=1) from e


# @app.command("webpage")
# def load_webpage(
#    name: Annotated[str, typer.Option(help="Name of dataset to load.")],
#    urls: Annotated[List[str], typer.Option(help="List of urls to load.")],
# ):
#    try:
#        from llama_index.readers.web import SimpleWebPageReader
#
#        docs = SimpleWebPageReader(html_to_text=True).load_data(urls)
#        store_docs(name, docs)
#
#    except ValueError as e:
#        typer.secho(f"Failed to load webpage from provided information.\n{e}", fg=typer.colors.RED)


@app.command("vector-database")
def load_vector_database(
    name: Annotated[str, typer.Option(help="Name of dataset to load.")],
    uri: Annotated[str, typer.Option(help="Database URI.")],
    table_name: Annotated[str, typer.Option(help="Name of table containing data.")],
    text_column: Annotated[str, typer.Option(help="Name of column containing text.")],
    embedding_column: Annotated[str, typer.Option(help="Name of column containing embedding.")],
    user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
):
    """Load pre-computed embeddings into MemGPT from a database (currently not implemented)."""
    # Every argument is ignored for now: the command always raises NotImplementedError.
    # The message tells CLI users why the command failed instead of leaving a bare traceback.
    raise NotImplementedError("Loading data from a vector database is not currently supported.")
    # try:
    #    config = MemGPTConfig.load()
    #    connector = VectorDBConnector(
    #        uri=uri,
    #        table_name=table_name,
    #        text_column=text_column,
    #        embedding_column=embedding_column,
    #        embedding_dim=config.default_embedding_config.embedding_dim,
    #    )
    #    if not user_id:
    #        user_id = uuid.UUID(config.anon_clientid)

    #    ms = MetadataStore(config)
    #    source = Source(
    #        name=name,
    #        user_id=user_id,
    #        embedding_model=config.default_embedding_config.embedding_model,
    #        embedding_dim=config.default_embedding_config.embedding_dim,
    #    )
    #    ms.create_source(source)
    #    passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
    #    # TODO: also get document store

    #    # ingest data into passage/document store
    #    try:
    #        num_passages, num_documents = load_data(
    #            connector=connector,
    #            source=source,
    #            embedding_config=config.default_embedding_config,
    #            document_store=None,
    #            passage_store=passage_storage,
    #        )
    #        print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
    #    except Exception as e:
    #        typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
    #        ms.delete_source(source_id=source.id)

    # except ValueError as e:
    #    typer.secho(f"Failed to load VectorDB from provided information.\n{e}", fg=typer.colors.RED)
    #    raise
